{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from math import ceil\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from utils.input_pipeline import get_image_folders\n",
    "from utils.training import train\n",
    "from utils.quantization import optimization_step, quantize, initial_scales\n",
    "\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 1e-4  # learning rate for all possible weights\n",
    "HYPERPARAMETER_T = 0.15  # hyperparameter for quantization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create data iterators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_folder, val_folder = get_image_folders()\n",
    "\n",
    "train_iterator = DataLoader(\n",
    "    train_folder, batch_size=batch_size, num_workers=4,\n",
    "    shuffle=True, pin_memory=True\n",
    ")\n",
    "\n",
    "val_iterator = DataLoader(\n",
    "    val_folder, batch_size=256, num_workers=4,\n",
    "    shuffle=False, pin_memory=True\n",
    ")\n",
    "\n",
    "# number of training samples\n",
    "train_size = len(train_folder.imgs)\n",
    "train_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from get_densenet import get_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model, loss, optimizer = get_model(learning_rate=LEARNING_RATE)\n",
    "\n",
    "# load pretrained model, accuracy ~73%\n",
    "model.load_state_dict(torch.load('../vanilla_densenet_big/model_step5.pytorch_state'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### keep copy of full precision kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# copy almost all full precision kernels of the model\n",
    "all_fp_kernels = [\n",
    "    Variable(kernel.data.clone(), requires_grad=True) \n",
    "    for kernel in optimizer.param_groups[1]['params']\n",
    "]\n",
    "# all_fp_kernels - kernel tensors of all convolutional layers \n",
    "# (with the exception of the first conv layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### initial quantization "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scaling factors for each quantized layer\n",
    "initial_scaling_factors = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# these kernels will be quantized\n",
    "all_kernels = [kernel for kernel in optimizer.param_groups[1]['params']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, k_fp in zip(all_kernels, all_fp_kernels):\n",
    "    \n",
    "    # choose initial scaling factors \n",
    "    w_p_initial, w_n_initial = initial_scales(k_fp.data)\n",
    "    initial_scaling_factors += [(w_p_initial, w_n_initial)]\n",
    "    \n",
    "    # do quantization\n",
    "    k.data = quantize(k_fp.data, w_p_initial, w_n_initial, t=HYPERPARAMETER_T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### parameter updaters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimizer for updating only all_fp_kernels\n",
    "optimizer_fp = optim.Adam(all_fp_kernels, lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimizer for updating only scaling factors\n",
    "optimizer_sf = optim.Adam([\n",
    "    Variable(torch.FloatTensor([w_p, w_n]).cuda(), requires_grad=True) \n",
    "    for w_p, w_n in initial_scaling_factors\n",
    "], lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1563"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "class lr_scheduler_list:\n",
    "    \"\"\"ReduceLROnPlateau for a list of optimizers.\"\"\"\n",
    "    def __init__(self, optimizer_list):\n",
    "        self.lr_scheduler_list = [\n",
    "            ReduceLROnPlateau(\n",
    "                optimizer, mode='max', factor=0.1, patience=3, \n",
    "                verbose=True, threshold=0.01, threshold_mode='abs'\n",
    "            ) \n",
    "            for optimizer in optimizer_list\n",
    "        ]\n",
    "    \n",
    "    def step(self, test_accuracy):\n",
    "        for scheduler in self.lr_scheduler_list:\n",
    "            scheduler.step(test_accuracy)\n",
    "\n",
    "n_epochs = 15\n",
    "n_batches = ceil(train_size/batch_size)\n",
    "\n",
    "# total number of batches in the train set\n",
    "n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  3.605 2.637  0.213 0.369  0.454 0.656  674.036\n",
      "1  2.571 2.173  0.384 0.473  0.670 0.744  668.391\n",
      "2  2.276 2.121  0.444 0.487  0.724 0.751  667.756\n",
      "3  2.116 1.922  0.478 0.520  0.751 0.784  668.168\n",
      "4  1.995 1.934  0.505 0.522  0.772 0.783  668.347\n",
      "5  1.914 1.824  0.522 0.546  0.786 0.802  668.592\n",
      "6  1.843 1.749  0.538 0.567  0.797 0.812  668.261\n",
      "7  1.788 1.845  0.549 0.541  0.806 0.795  668.585\n",
      "8  1.737 1.718  0.561 0.571  0.814 0.812  668.550\n",
      "9  1.693 1.861  0.571 0.547  0.819 0.792  668.989\n",
      "10  1.655 1.689  0.579 0.583  0.825 0.819  668.097\n",
      "11  1.619 1.647  0.587 0.590  0.831 0.824  668.579\n",
      "12  1.586 1.545  0.594 0.616  0.836 0.844  668.487\n",
      "13  1.552 1.699  0.602 0.575  0.842 0.814  668.539\n",
      "14  1.521 1.600  0.609 0.601  0.846 0.833  668.464\n",
      "CPU times: user 2h 56min 43s, sys: 27min 18s, total: 3h 24min 2s\n",
      "Wall time: 2h 47min 11s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "optimizer_list = [optimizer, optimizer_fp, optimizer_sf]\n",
    "\n",
    "def optimization_step_fn(model, loss, x_batch, y_batch):\n",
    "    return optimization_step(\n",
    "        model, loss, x_batch, y_batch, \n",
    "        optimizer_list=optimizer_list,\n",
    "        t=HYPERPARAMETER_T\n",
    "    )\n",
    "all_losses = train(\n",
    "    model, loss, optimization_step_fn,\n",
    "    train_iterator, val_iterator, n_epochs,\n",
    "    lr_scheduler=lr_scheduler_list(optimizer_list)        \n",
    ")\n",
    "# epoch logloss  accuracy    top5_accuracy time  (first value: train, second value: val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backup\n",
    "model.cpu();\n",
    "torch.save(model.state_dict(), 'model_ternary_quantization.pytorch_state')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Continue training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reduce learning rate\n",
    "for optimizer in optimizer_list:\n",
    "    for group in optimizer.param_groups:\n",
    "        group['lr'] = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "model.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  1.312 1.362  0.660 0.654  0.875 0.866  649.648\n",
      "1  1.243 1.341  0.675 0.661  0.885 0.868  651.109\n",
      "2  1.222 1.366  0.678 0.657  0.889 0.866  651.324\n",
      "3  1.207 1.334  0.682 0.664  0.890 0.870  651.090\n",
      "4  1.195 1.350  0.685 0.659  0.891 0.869  651.020\n",
      "CPU times: user 57min 55s, sys: 8min 45s, total: 1h 6min 41s\n",
      "Wall time: 54min 14s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def optimization_step_fn(model, loss, x_batch, y_batch):\n",
    "    return optimization_step(\n",
    "        model, loss, x_batch, y_batch, \n",
    "        optimizer_list=optimizer_list,\n",
    "        t=HYPERPARAMETER_T\n",
    "    )\n",
    "all_losses = train(\n",
    "    model, loss, optimization_step_fn,\n",
    "    train_iterator, val_iterator, n_epochs       \n",
    ")\n",
    "# epoch logloss  accuracy    top5_accuracy time  (first value: train, second value: val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.cpu();\n",
    "torch.save(model.state_dict(), 'model_ternary_quantization.pytorch_state')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
